#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Note: further edited to also accept SwAV-style backbones (and DINO "teacher" checkpoints).

import pickle as pkl
import sys
import torch

# Key prefix under which MoCo stores the query encoder's weights.
MOCO_PREFIX = "module.encoder_q."


def rename_key(k):
    """Map a torchvision-style ResNet parameter name to detectron2's naming.

    - keys without a ``"layer"`` substring are placed under ``stem.``;
    - ``layerN`` becomes ``res{N+1}`` (N = 1..4);
    - ``bnN`` becomes ``convN.norm`` (N = 1..3);
    - ``downsample.0`` / ``downsample.1`` become ``shortcut`` / ``shortcut.norm``.

    Replacements are plain substring replacements applied in that order.
    """
    if "layer" not in k:
        k = "stem." + k
    for t in [1, 2, 3, 4]:
        k = k.replace("layer{}".format(t), "res{}".format(t + 1))
    for t in [1, 2, 3]:
        k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
    k = k.replace("downsample.0", "shortcut")
    k = k.replace("downsample.1", "shortcut.norm")
    return k


def convert_state_dict(obj):
    """Convert a (MoCo / SwAV-style) state dict into detectron2 naming.

    If any key contains ``MOCO_PREFIX`` the checkpoint is treated as MoCo and
    only the query-encoder keys (those starting with ``MOCO_PREFIX``) are kept.
    Otherwise every ``module.`` and ``backbone.`` substring is stripped from the
    keys, and ``projection_head`` / ``prototypes`` entries are dropped.

    Each mapping is printed as ``old_key -> new_key``.

    Args:
        obj: mapping from parameter name to tensor.

    Returns:
        dict mapping detectron2 parameter names to numpy arrays.
    """
    newmodel = {}
    if any(MOCO_PREFIX in k for k in obj):
        # moco model
        for old_k, v in obj.items():
            if not old_k.startswith(MOCO_PREFIX):
                continue
            k = rename_key(old_k.replace(MOCO_PREFIX, ""))
            print(old_k, "->", k)
            # detach() so tensors that require grad (e.g. nn.Parameter) convert too.
            newmodel[k] = v.detach().numpy()
    else:
        print('assuming model is not MOCO', flush=True)
        obj = {k.replace('module.', ''): v for k, v in obj.items()}
        obj = {k.replace('backbone.', ''): v for k, v in obj.items()}
        for old_k, v in obj.items():
            # Self-supervised heads are not part of the backbone.
            if ('projection_head' in old_k) or ('prototypes' in old_k):
                continue
            k = rename_key(old_k)
            print(old_k, "->", k)
            newmodel[k] = v.detach().numpy()
    return newmodel


def main(argv=None):
    """Convert ``argv[1]`` (a torch checkpoint) into a detectron2 pickle at ``argv[2]``.

    A top-level ``state_dict`` entry, and then a ``teacher`` entry, are unwrapped
    when present before conversion. Exits with a usage message when fewer
    than two paths are given.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        prog = argv[0] if argv else "convert"
        sys.exit("usage: {} INPUT_CHECKPOINT OUTPUT_PKL".format(prog))
    src_path, dst_path = argv[1], argv[2]

    # NOTE: torch.load unpickles arbitrary objects; only convert trusted checkpoints.
    obj = torch.load(src_path, map_location="cpu")
    if "state_dict" in obj:
        obj = obj["state_dict"]
    if "teacher" in obj:
        obj = obj["teacher"]

    newmodel = convert_state_dict(obj)

    res = {"model": newmodel, "__author__": "MOCO", "matching_heuristics": True}

    with open(dst_path, "wb") as f:
        pkl.dump(res, f)


if __name__ == "__main__":
    main()
